import pickle
import os

import cv2

from global_vars import BOUNDS, PIXEL_SIZE, CAMERA_CONFIG
import utils.transporter_utils as utils

import ipdb
st = ipdb.set_trace  # short alias for ipdb.set_trace, kept as a debugging convenience (call st() to break)


def load_all(folder, n_demos):
    """Load up to ``n_demos`` recorded episodes from ``folder`` and export EBM data.

    ``folder`` must contain the sub-directories ``action``, ``color``, ``depth``,
    ``reward`` and ``info``. Each episode is stored as one pickle with the same
    file name in every sub-directory. Episode files are taken in sorted order
    from ``folder/action``, and only names containing ``'.pkl'`` are considered.

    Side effects:
        * creates ``folder/vis`` and ``folder/ebm`` if they are missing;
        * writes one JPEG per step into ``folder/ebm``, named
          ``<prefix>_<step>.jpg`` (prefix = file name up to the first ``'-'``);
        * writes ``lang_relations_{n_demos}.pickle`` (list of every step's
          ``info['lang_goal']``) and ``gt_bbox_relations_{n_demos}.pickle``
          (dict mapping each JPEG name to that step's ``info``) into
          ``folder/ebm``.

    Args:
        folder: dataset directory for one task.
        n_demos: maximum number of episode files to load. A value that the
            episode counter never equals (e.g. ``None``) loads all of them.

    Returns:
        List of ``(obs, action, reward, info)`` tuples, one per step, where
        ``obs`` is ``{'color': ..., 'depth': ...}``.
    """
    save_ebm = True

    def load_field(field, fname):
        """Unpickle ``folder/<field>/<fname>``."""
        # NOTE: pickle can execute arbitrary code while loading; only use this
        # on trusted, locally generated datasets.
        with open(os.path.join(folder, field, fname), 'rb') as f:
            return pickle.load(f)

    # The 'vis' directory is kept for (currently disabled) visualizations.
    vis_path = os.path.join(folder, 'vis')
    ebm_path = os.path.join(folder, 'ebm')
    # Create the output directories up front so the summary pickles below can
    # always be written, even when no episode file is found.
    os.makedirs(vis_path, exist_ok=True)
    os.makedirs(ebm_path, exist_ok=True)

    # Only real episode files count towards n_demos.
    episode_files = [
        fname for fname in sorted(os.listdir(os.path.join(folder, 'action')))
        if '.pkl' in fname
    ]

    episode = []
    lan_dic = []
    gt_bbox = {}
    for n_loaded, fname in enumerate(episode_files):
        if n_loaded == n_demos:
            break
        color = load_field('color', fname)
        depth = load_field('depth', fname)
        action = load_field('action', fname)
        reward = load_field('reward', fname)
        info = load_field('info', fname)

        # Reconstruct the episode step by step.
        prefix = fname.split('-')[0]
        for step in range(len(action)):
            obs = {'color': color[step], 'depth': depth[step]}
            if save_ebm:
                img = get_image(obs)
                img_name = "{}_{}.jpg".format(prefix, step)
                # Reverse the channel axis (RGB -> BGR for cv2.imwrite,
                # assuming get_image returns RGB) and swap the first two axes.
                cv2.imwrite(os.path.join(ebm_path, img_name),
                            img[:, :, ::-1].transpose(1, 0, 2))
                gt_bbox[img_name] = info[step]

            episode.append((obs, action[step], reward[step], info[step]))
            lan_dic.append(info[step]['lang_goal'])

    if save_ebm:
        with open(os.path.join(ebm_path, f'lang_relations_{n_demos}.pickle'), 'wb') as outfile:
            pickle.dump(lan_dic, outfile)
        with open(os.path.join(ebm_path, f'gt_bbox_relations_{n_demos}.pickle'), 'wb') as outfile:
            pickle.dump(gt_bbox, outfile)
    return episode


def get_image(obs, cam_config=None):
    """Return the fused color map for one observation.

    Args:
        obs: observation passed straight to ``utils.get_fused_heightmap``.
            ``load_all`` builds it as ``{'color': ..., 'depth': ...}``.
        cam_config: camera configuration to fuse with. When ``None`` (the
            default), the module-level ``CAMERA_CONFIG`` is used. Earlier
            versions ignored this argument and always used ``CAMERA_CONFIG``.

    Returns:
        The color map, i.e. the first output of ``utils.get_fused_heightmap``.
        The height map it also returns is discarded.
    """
    if cam_config is None:
        cam_config = CAMERA_CONFIG
    # Get color and height maps from RGB-D images; only the color map is used.
    cmap, _ = utils.get_fused_heightmap(obs, cam_config, BOUNDS, PIXEL_SIZE)
    return cmap

def parse_pickle(data_root='/projects/""/ns_transporter_data/transporter_data_sep_100d_new/',
                 n_demos=10, is_relations=False):
    """Run ``load_all`` over every selected task folder under ``data_root``.

    Entries are visited in sorted order. An entry is skipped when its name
    contains any of ``'.pickle'``, ``'.zip'``, ``'detector'``, ``'letter'``,
    ``'org'`` or ``'composition'``. Among the remaining entries:

    * with ``is_relations=False``, names containing a relation keyword
      (``'left'``, ``'right'``, ``'below'``, ``'above'``) are skipped;
    * with ``is_relations=True``, only names containing such a keyword are kept.

    Args:
        data_root: directory holding one sub-folder per task.
        n_demos: forwarded to ``load_all``.
        is_relations: selects relation tasks instead of non-relation tasks.
    """
    excluded = ('.pickle', '.zip', 'detector', 'letter', 'org')
    relations = ('left', 'right', 'below', 'above')
    for entry in sorted(os.listdir(data_root)):
        if any(tag in entry for tag in excluded):
            continue
        has_relation = any(rel in entry for rel in relations)
        # Keep exactly the folders whose relation-ness matches the requested
        # mode; composition tasks are never processed here.
        if has_relation != is_relations or 'composition' in entry:
            continue

        print(entry)
        load_all(os.path.join(data_root, entry), n_demos=n_demos)



if __name__ == '__main__':
    # Script entry point: run the batch conversion over the dataset folders.
    parse_pickle()